import os
import sys
import tlsh
import pefile
import pehash
import hashlib
import pymongo
import argparse
import multiprocessing


def get_meta(file_path):
    """Compute the hashes and PE metadata stored for one sample.

    Args:
        file_path: Path of the sample on disk.

    Returns:
        A dict with the keys ``md5``, ``pehash``, ``imphash``,
        ``num_imports``, ``compile_time``, ``tlsh`` and ``rich_hash``.
        ``md5`` and ``tlsh`` are always computed; the remaining fields stay
        ``None`` when ``pefile`` cannot parse the file, and ``pehash``,
        ``imphash`` and ``rich_hash`` also stay ``None`` when the
        corresponding value is not available for a parsed file.

    Raises:
        OSError: If the sample cannot be opened or read.
    """
    # Read the sample once and close the handle right away; the same bytes
    # feed both the MD5 fallback and the TLSH digest.
    with open(file_path, "rb") as sample:
        data = sample.read()

    # A file name with exactly one underscore is taken to carry the MD5 after
    # the underscore; any other name means the digest is computed from the
    # file contents.
    name_parts = os.path.basename(file_path).split("_")
    if len(name_parts) == 2:
        md5 = name_parts[1].lower()
    else:
        md5 = hashlib.md5(data).hexdigest()

    meta = {
        "md5": md5,
        "pehash": None,
        "imphash": None,
        "num_imports": None,
        "compile_time": None,
        "tlsh": tlsh.hash(data),
        "rich_hash": None
    }

    # Best effort: anything pefile refuses to parse keeps only the
    # format-independent fields filled in above.
    try:
        pe = pefile.PE(file_path)
    except Exception:
        return meta

    try:
        # Get pehash
        pe_hash = pehash.totalhash(pe=pe)
        if pe_hash is not None:
            meta["pehash"] = pe_hash.hexdigest()

        # Get imphash (a falsy result is stored as None)
        meta["imphash"] = pe.get_imphash() or None

        # Get number of imports, summed over every imported module
        meta["num_imports"] = sum(
            1
            for entry in getattr(pe, "DIRECTORY_ENTRY_IMPORT", ())
            for _ in entry.imports
        )

        # Get compilation timestamp
        meta["compile_time"] = pe.FILE_HEADER.TimeDateStamp

        # Get Rich hash: MD5 over the Rich header's "clear_data" bytes
        rich_header = pe.parse_rich_header()
        if rich_header is not None:
            meta["rich_hash"] = hashlib.md5(
                rich_header["clear_data"]).hexdigest()
    finally:
        # Release the file data held by the PE object so that long-running
        # worker processes do not accumulate open samples.
        pe.close()

    return meta


if __name__ == "__main__":
    # Script entry point: rebuild the "metadata" collection from scratch and
    # fill it with the get_meta() result of every file found under
    # --malware-dir.

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--malware-dir",
                        default="/media/data1/malware/")
    parser.add_argument("--db-name", default="agtr_db")
    args = parser.parse_args()

    # Re-initialize db table
    collection_name = "metadata"
    client = pymongo.MongoClient("127.0.0.1", 27017)
    try:
        db = client[args.db_name]
        collection = db[collection_name]
        # Collection.count() was removed in PyMongo 4; find_one() answers the
        # same "is there at least one document?" question on every version.
        if collection.find_one() is not None:
            print("[+] Dropping collection {}".format(collection_name))
            sys.stdout.flush()
            collection.drop()
        collection.create_index([("md5", pymongo.HASHED)])
        print("[+] Created database collection {}".format(collection_name))
        sys.stdout.flush()
    finally:
        # Closed before the worker pool is created below.
        # NOTE(review): presumably so that no live MongoDB connection is
        # inherited by forked workers - confirm before reusing one client.
        client.close()

    # Get file paths
    file_paths = [os.path.join(root, file_name)
                  for root, _dirs, file_names in os.walk(args.malware_dir)
                  for file_name in file_names]

    # Split file paths into batches
    batch_size = 1000
    num_files = len(file_paths)
    batches = [file_paths[start:start + batch_size]
               for start in range(0, num_files, batch_size)]

    # Compute metadata hashes of each malware sample
    total_processed = 0
    pool = multiprocessing.Pool(12)
    try:
        for batch in batches:
            # Drop falsy results so only real documents are inserted.
            all_meta = [meta for meta in pool.map(get_meta, batch) if meta]
            if not all_meta:
                continue
            # A short-lived client per batch, always closed even when the
            # insert fails.
            client = pymongo.MongoClient("127.0.0.1", 27017)
            try:
                db = client[args.db_name]
                db[collection_name].insert_many(all_meta, ordered=False)
            finally:
                client.close()
            total_processed += len(all_meta)
            print("[-] Processed {} malware samples".format(total_processed))
            sys.stdout.flush()
    finally:
        # Always shut the workers down, including when a batch raised.
        pool.close()
        pool.join()
